'''
Created on 29-Mar-2013

@author: cdac
'''

import csv
from random import sample

class DataSource:
    '''
    In-memory collection of messages read from a tab-separated file.

    Each row of the file is expected to look like

        <label> \t <message>

    e.g. "Spam \t Msg" or "Ham \t Msg". Only the second column (the
    message text) is kept; the label column is discarded.
    '''

    def __init__(self, file_source):
        '''
        Load the message column of every row in `file_source`.

        Blank rows are skipped. A non-blank row without a tab (fewer than
        two columns) raises IndexError.
        '''
        self.__file_source = file_source
        self.__msg_list = list()
        # 'rb' is the mode the Python 2 csv module asks for; the with-block
        # guarantees the handle is closed once the file has been read.
        with open(self.__file_source, 'rb') as msgfd:
            for row in csv.reader(msgfd, delimiter='\t'):
                if not row:
                    # csv yields [] for an empty line, which has no column 1.
                    continue
                self.__msg_list.append(row[1])

    def get_all_msgs(self):
        '''
        Return every loaded message, in file order.

        Note: this is the internal list itself, not a copy, so mutating the
        result changes what later calls on this object return.
        '''
        return self.__msg_list

    def get_first_msgs(self, how_many):
        '''
        Return a new list with the first `how_many` messages in file order.

        Fewer are returned if fewer were loaded; an empty list is returned
        when `how_many` <= 0.
        '''
        # Clamp at 0 so a negative count never turns into a from-the-end slice.
        return self.__msg_list[:max(how_many, 0)]

    def get_random_msgs(self, how_many=1):
        '''
        Return `how_many` distinct messages chosen at random (random.sample).

        Raises ValueError if `how_many` is negative or larger than the
        number of loaded messages.
        '''
        return sample(self.__msg_list, how_many)

    def get_msgs_in_range(self, from_msg_sample, to_msg_sample):
        '''
        Return the messages at indices from_msg_sample .. to_msg_sample - 1.

        Unlike slicing, an index past the end raises IndexError rather than
        being silently truncated.
        '''
        msg_list = self.__msg_list
        return [msg_list[index]
                for index in range(from_msg_sample, to_msg_sample)]
        
if __name__ == '__main__':
    
    small_ham_file  = '../Data/HamSmall.csv'
    small_spam_file = '../Data/SpamSmall.csv'

    large_ham_file  = '../Data/Ham.csv'
    large_spam_file = '../Data/Spam.csv'

    
    # ds = DataSource('../Data/Ham.csv', '../Data/Spam.csv')
    ds_hams = DataSource(small_ham_file)
    ds_spams = DataSource(small_spam_file)
    
    hams    = ds_hams.get_first_msgs(2) 
    spams   = ds_spams.get_first_msgs(2)   
    print hams, len(hams)
    print spams, len(spams)
    
    hams = ds_hams.get_random_msgs(how_many=5)
    spams = ds_spams.get_random_msgs(how_many=5)
    print hams, len(hams)
    print spams, len(spams)
    
    hams = ds_hams.get_msgs_in_range(0, 10)
    spams = ds_spams.get_msgs_in_range(0, 5)
    print hams, len(hams)
    print spams, len(spams)
    
    